package com.stylefeng.guns.modular.prediction.util;

import com.stylefeng.guns.common.exception.BussinessException;
import com.stylefeng.guns.core.util.DoubleUtil;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;

import static com.stylefeng.guns.common.exception.BizExceptionEnum.CALCULATE_ERROR;
import static java.util.stream.Collectors.toList;

/**
 * Single-factor time-series forecasting based on the GM(1,1) grey model.
 *
 * <p>The observations are accumulated into a running-sum series. The development coefficient
 * {@code a} and the grey input {@code b} are then estimated by least squares against the
 * background values of that series. The time-response function of the accumulated series is
 * evaluated and differenced back to the original scale.
 */
public class OneFactorForecasting {

    /**
     * Development coefficients whose magnitude is below this value are treated as zero.
     * The time-response function divides by {@code a}. For a (nearly) constant series,
     * {@code a} is 0 or pure rounding noise, and that division fails or produces garbage
     * through catastrophic cancellation. In the limit {@code a -> 0}, every restored value
     * equals {@code b}.
     */
    private static final double DEVELOPMENT_COEFFICIENT_EPSILON = 1e-9;

    /**
     * Minimum number of observations for which the two-parameter least-squares system can be
     * non-singular. With n observations the design matrix has n - 1 rows, so it needs at
     * least 2 rows (3 observations).
     */
    private static final int MIN_OBSERVATIONS = 3;

    /**
     * Single-factor forecast that extends the series by one future period.
     *
     * @param actualList observed values in chronological order
     * @return fitted values for the observations followed by one forecast value;
     *         see {@link #forecast(List, int)}
     */
    public static List<Double> forecast(List<Integer> actualList) {
        return forecast(actualList, 1);
    }

    /**
     * Fits a GM(1,1) model to {@code actualList} and returns fitted plus forecast values.
     *
     * <p>The returned list has {@code actualList.size() + after} entries:
     * <ul>
     *   <li>element 0 is the first observation itself;</li>
     *   <li>the following {@code actualList.size() - 1} elements are the model's fitted values
     *       for the remaining observations;</li>
     *   <li>the last {@code after} elements are forecasts.</li>
     * </ul>
     * If every observation is 0 (including an empty list), the result is
     * {@code actualList.size() + after} zeros.
     *
     * @param actualList observed values in chronological order; must not be {@code null} and
     *                   must not contain {@code null} elements
     * @param after      number of future periods to forecast (expected to be non-negative)
     * @return a mutable list of fitted and forecast values
     * @throws BussinessException with {@code CALCULATE_ERROR} if a non-zero series has fewer
     *                            than three observations or the least-squares system is singular
     */
    public static List<Double> forecast(List<Integer> actualList, int after) {
        if (actualList.stream().allMatch(actual -> actual == 0)) {
            return zeros(actualList.size() + after);
        }

        int rowNum = actualList.size();
        if (rowNum < MIN_OBSERVATIONS) {
            // The normal equations are rank-deficient here. Fail deterministically instead of
            // relying on det() rounding to exactly 0.0, which LU decomposition does not guarantee.
            throw new BussinessException(CALCULATE_ERROR);
        }

        long[] accumulated = accumulate(actualList);

        Matrix design = DenseMatrix.Factory.zeros(rowNum - 1, 2);
        Matrix observed = DenseMatrix.Factory.zeros(rowNum - 1, 1);
        for (int i = 1; i < rowNum; i++) {
            // Background value: negated mean of two consecutive accumulated values. The sum is an
            // integer, and scaling by -0.5 is exact in binary floating point (while |sum| < 2^53).
            design.setAsDouble(-0.5 * (accumulated[i] + accumulated[i - 1]), i - 1, 0);
            design.setAsDouble(1, i - 1, 1);
            observed.setAsDouble(actualList.get(i), i - 1, 0);
        }

        // Least squares: [a, b]^T = (B^T B)^-1 B^T Y
        Matrix transposed = design.transpose();
        Matrix normal = transposed.mtimes(design);
        if (normal.det() == 0) {
            throw new BussinessException(CALCULATE_ERROR);
        }
        Matrix params = normal.inv().mtimes(transposed).mtimes(observed);
        double a = params.getAsDouble(0, 0);
        double b = params.getAsDouble(1, 0);

        int total = rowNum + after;
        List<Double> forecastingList = new ArrayList<>();
        forecastingList.add(actualList.get(0).doubleValue());

        if (Math.abs(a) < DEVELOPMENT_COEFFICIENT_EPSILON) {
            // a -> 0 limit of the time-response function: x1(k) = q0 + b*k, so each restored
            // value is b. This avoids dividing by (almost) zero in cal().
            for (int k = 1; k < total; k++) {
                forecastingList.add(b);
            }
            return forecastingList;
        }

        Integer q0 = actualList.get(0);
        for (int k = 1; k < total; k++) {
            // Restore the original scale by differencing consecutive accumulated predictions.
            Double forecastingVal = DoubleUtil.sub(
                    cal(a, b, q0, k),
                    cal(a, b, q0, k - 1)
            );
            forecastingList.add(forecastingVal);
        }
        return forecastingList;
    }

    /**
     * Returns a mutable list of {@code size} zeros; a non-positive size yields an empty list.
     */
    private static List<Double> zeros(int size) {
        List<Double> zeros = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            zeros.add(0d);
        }
        return zeros;
    }

    /**
     * Returns the running (prefix) sums of {@code values} in a single pass, accumulating in
     * {@code long} so large series do not overflow {@code int}.
     */
    private static long[] accumulate(List<Integer> values) {
        long[] sums = new long[values.size()];
        long running = 0;
        int index = 0;
        for (int value : values) {
            running += value;
            sums[index++] = running;
        }
        return sums;
    }

    /**
     * Evaluates the GM(1,1) time-response function of the accumulated series:
     * {@code (q0 - b/a) * e^(-a*k) + b/a}.
     *
     * <p>{@code a} must not be (close to) zero; callers handle that case separately.
     *
     * @param a  development coefficient
     * @param b  grey input
     * @param q0 first observation (the accumulated series' value at k = 0)
     * @param k  time index
     * @return predicted accumulated value at index {@code k}
     */
    private static Double cal(Double a, Double b, Integer q0, Integer k) {
        return DoubleUtil.add(
                DoubleUtil.mul(
                        DoubleUtil.sub(
                                q0,
                                DoubleUtil.div(b, a)
                        ),
                        Math.pow(
                                Math.E,
                                DoubleUtil.mul(
                                        DoubleUtil.sub(0, a),
                                        k
                                )
                        )
                ),
                DoubleUtil.div(b, a)
        );
    }

}
